tidymodels, a machine learning framework for R
When building models so far, we have used each model's default hyperparameters. To obtain a robust and accurate model, the hyperparameters need to be tuned. The tidymodels framework provides the tune, rsample, and dials packages, among others, for hyperparameter tuning.

Load packages

    library(tidymodels)
    ## ── Attaching packages ────────────────────────────────────── tidymodels 1.0.0 ──
    ## ✔ broom        1.0.3     ✔ recipes      1.0.5
    ## ✔ dials        1.1.0     ✔ rsample      1.1.1
    ## ✔ dplyr        1.1.0     ✔ tibble       3.1.8
    ## ✔ ggplot2      3.4.1     ✔ tidyr        1.3.0
    ## ✔ infer        1.0.4     ✔ tune         1.0.1
    ## ✔ modeldata    1.1.0     ✔ workflows    1.1.3
    ## ✔ parsnip      1.0.4     ✔ workflowsets 1.0.0
    ## ✔ purrr        1.0.1     ✔ yardstick    1.1.0

Load data

    # Read the CSV file and drop rows with missing values
    mpe <- read_csv(…) %>% drop_na()
    ## Rows: 456 Columns: 14
    ## ── Column specification ────────────────────────────────────────────────────────
    ## Delimiter: ","
    ## dbl (14): MPE, Gender, Age, Fever, Cough, ChestPain, WBCPE, LDHS, TPPE, TPPE...
    ##
    ## ℹ Use `spec()` to retrieve the full column specification for this data.
    ## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Split the dataset

    set.seed(123)
    split <- …
    …
    xgb_res %>% collect_metrics()
    ## # A tibble: 10 × 11
    ##    trees min_n tree_depth learn_r…¹ loss_r…² .metric .esti…³  mean     n std_err
    ##    <int> <int>      <int>     <dbl>    <dbl> <chr>   <chr>   <dbl> <int>   <dbl>
    ##  1   141    34          1   2.12e-3 1.94e- 5 accura… binary  0.738     5  0.0175
    ##  2   141    34          1   2.12e-3 1.94e- 5 roc_auc binary  0.5       5  0
    ##  3   403     6          7   2.21e-4 1.48e+ 1 accura… binary  0.754     5  0.0188
    ##  4   403     6          7   2.21e-4 1.48e+ 1 roc_auc binary  0.784     5  0.0238
    ##  5   386    17         10   9.02e-9 3.94e- 1 accura… binary  0.708     5  0.0175
    ##  6   386    17         10   9.02e-9 3.94e- 1 roc_auc binary  0.764     5  0.0310
    ##  7   461    37         11   8.64e-9 6.47e-10 accura… binary  0.738     5  0.0175
    ##  8   461    37         11   8.64e-9 6.47e-10 roc_auc binary  0.5       5  0
    ##  9    31    17         12   5.89e-8 1.17e- 2 accura… binary  0.708     5  0.0175
    ## 10    31    17         12   5.89e-8 1.17e- 2 roc_auc binary  0.764     5  0.0310
    ## # … with 1 more variable: .config <chr>, and abbreviated variable names
    ## #   ¹learn_rate, ²loss_reduction, ³.estimator

Each of the five candidate hyperparameter combinations appears twice, once per metric (accuracy and roc_auc); n = 5 is the number of resamples each estimate is averaged over.

Select the best parameters

    best_xgb <- xgb_res %>% select_best(metric = "accuracy")
    best_xgb
    ## # A tibble: 1 × 6
    ##   trees min_n tree_depth learn_rate loss_reduction .config
    ##   <int> <int>      <int>      <dbl>          <dbl> <chr>
    ## 1   403     6          7   0.000221           14.8 Preprocessor1_Model2

View the models with the highest accuracy

    show_best(xgb_res, metric = "accuracy")  # sorted by accuracy, highest first
    ## # A tibble: 5 × 11
    ##   trees min_n tree_depth learn_rate loss_r…¹ .metric .esti…²  mean     n std_err
    ##   <int> <int>      <int>      <dbl>    <dbl> <chr>   <chr>   <dbl> <int>   <dbl>
    ## 1   403     6          7    2.21e-4 1.48e+ 1 accura… binary  0.754     5  0.0188
    ## 2   141    34          1    2.12e-3 1.94e- 5 accura… binary  0.738     5  0.0175
    ## 3   461    37         11    8.64e-9 6.47e-10 accura… binary  0.738     5  0.0175
    ## 4   386    17         10    9.02e-9 3.94e- 1 accura… binary  0.708     5  0.0175
    ## 5    31    17         12    5.89e-8 1.17e- 2 accura… binary  0.708     5  0.0175
    ## # … with 1 more variable: .config <chr>, and abbreviated variable names
    ## #   ¹loss_reduction, ².estimator

Final model

    # Insert the best hyperparameter values into the workflow
    final_wf <- … %>% finalize_workflow(best_xgb)
    final_wf
    ## ══ Workflow ════════════════════════════════════════════════════════════════════
    ## Preprocessor: Formula
    ## Model: boost_tree()
    ##
    ## ── Preprocessor ────────────────────────────────────────────────────────────────
    ## factor(MPE) ~ .
    ##
    ## ── Model ───────────────────────────────────────────────────────────────────────
    ## Boosted Tree Model Specification (classification)
    ##
    ## Main Arguments:
    ##   trees = 403
    ##   min_n = 6
    ##   tree_depth = 7
    ##   learn_rate = 0.000220724331930274
    ##   loss_reduction = 14.8435496067182
    ##
    ## Computational engine: xgboost

Model performance on the test set

    # Fit on the training set and evaluate on the test set in one step
    final_fit <- final_wf %>% last_fit(split)
    final_fit %>% collect_metrics()
    ## # A tibble: 2 × 4
    ##   .metric  .estimator .estimate .config
    ##   <chr>    <chr>          <dbl> <chr>
    ## 1 accuracy binary         0.732 Preprocessor1_Model1
    ## 2 roc_auc  binary         0.729 Preprocessor1_Model1

On the test set, the accuracy is 0.732 and the AUC is 0.729.

    # Predictions
    xgb_pred <- final_fit %>% collect_predictions()
    # ROC curve; the outcome column is named `factor(MPE)` because of the formula
    xgb_pred %>% roc_curve(`factor(MPE)`, .pred_0) %>% autoplot()
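The tuning results above depend on steps whose code is not shown in full: splitting the data, marking hyperparameters for tuning, resampling, and running a grid search. The following is a minimal sketch of how such a run might look, not the article's own code. The `two_class_dat` dataset from modeldata stands in for the MPE data, and the names `train`, `xgb_spec`, `xgb_wf`, and `folds`, the 75/25 split, and the use of 5-fold cross-validation with a 5-candidate grid are all assumptions, chosen only to be consistent with `n = 5` and the five configurations in the output.

```r
library(tidymodels)  # the xgboost package must also be installed

# Stand-in data (assumption): two_class_dat ships with modeldata
data(two_class_dat, package = "modeldata")

set.seed(123)
# Hypothetical 75/25 split, stratified on the outcome
split <- initial_split(two_class_dat, prop = 0.75, strata = Class)
train <- training(split)

# Boosted tree specification; tune() marks each hyperparameter for tuning
xgb_spec <- boost_tree(
  trees = tune(),
  min_n = tune(),
  tree_depth = tune(),
  learn_rate = tune(),
  loss_reduction = tune()
) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

# Bundle the formula preprocessor and the model into one workflow
xgb_wf <- workflow() %>%
  add_formula(Class ~ .) %>%
  add_model(xgb_spec)

# 5-fold cross-validation on the training set (assumption)
folds <- vfold_cv(train, v = 5)

# Evaluate 5 automatically generated candidates on every fold
xgb_res <- tune_grid(
  xgb_wf,
  resamples = folds,
  grid = 5,
  metrics = metric_set(accuracy, roc_auc)
)

xgb_res %>% collect_metrics()
```

Passing an integer to `grid` lets `tune_grid()` generate the candidates from the default dials parameter ranges; a data frame of explicit hyperparameter values can be supplied instead when specific combinations need to be compared.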